from nas_201_api import NASBench201API as API
import numpy as np
import os

# The history of benchmark files:
# [2020.02.25] NAS-Bench-201-v1_0-e61699.pth : 6219 architectures are trained once, 1621 architectures are trained twice, 7785 architectures are trained three times. `LESS` only supports CIFAR10-VALID.
# [2020.03.16] NAS-Bench-201-v1_1-096897.pth : 2225 architectures are trained once, 5439 architectures are trained twice, 7961 architectures are trained three times on all training sets. For the hyper-parameters with the total epochs of 12, each model is trained on CIFAR-10, CIFAR-100, ImageNet16-120 once, and is trained on CIFAR-10-VALID twice.

# Labels used for the two terminal vertices of an encoded cell graph.
INPUT = "input"
OUTPUT = "output"

# Operation names recognised when encoding an architecture. The position of a
# name in this list determines its column in the one-hot operation encoding.
OPS = [
    "avg_pool_3x3",
    "nor_conv_1x1",
    "nor_conv_3x3",
    "none",
    "skip_connect",
]
NUM_OPS = len(OPS)

# presumably the number of operation slots (edges) in one cell -- TODO confirm
OP_SPOTS = 6
# presumably the longest input-to-output path, counted in operations -- TODO confirm
LONGEST_PATH_LENGTH = 3


class Nasbench201(object):
    """Wrapper around the NAS-Bench-201 API that encodes architectures as DAGs.

    An architecture is given as a NAS-Bench-201 string such as
    ``|op~0|+|op~0|op~1|+|op~0|op~1|op~2|``. ``dag_encoding`` turns it into a
    fixed-size graph (adjacency matrix + one-hot operations) together with the
    accuracies and time cost looked up in the benchmark.
    """

    # Fixed 8x8 connectivity used for every encoded cell. Row/column 0 is the
    # input vertex, 1-6 are the six operations, 7 is the output vertex.
    # NOTE(review): this wiring looks like it expects the operations in the
    # order [1<-0, 2<-0, 3<-0, 2<-1, 3<-1, 3<-2], whereas get_op_list() yields
    # [1<-0, 2<-0, 2<-1, 3<-0, 3<-1, 3<-2]. The author tried swapping entries
    # 2 and 3 and kept the unswapped order on purpose -- confirm before changing.
    _ADJACENCY = (
        (0, 1, 1, 1, 0, 0, 0, 0),
        (0, 0, 0, 0, 1, 1, 0, 0),
        (0, 0, 0, 0, 0, 0, 1, 0),
        (0, 0, 0, 0, 0, 0, 0, 1),
        (0, 0, 0, 0, 0, 0, 1, 0),
        (0, 0, 0, 0, 0, 0, 0, 1),
        (0, 0, 0, 0, 0, 0, 0, 1),
        (0, 0, 0, 0, 0, 0, 0, 0),
    )

    # Positions of the non-operation tokens ('' and '+') after splitting an
    # architecture string on '|'.
    _SEPARATOR_TOKENS = (0, 2, 5, 9)

    def __init__(self, data_path):
        """Load the benchmark file ``NAS-Bench-201-v1_0-e61699.pth`` from *data_path*."""
        self.api = API(os.path.join(data_path, "NAS-Bench-201-v1_0-e61699.pth"))

        self.edge2index = {'1<-0': 0, '2<-0': 1, '2<-1': 2, '3<-0': 3, '3<-1': 4, '3<-2': 5}
        self.op_names = ['none', 'skip_connect', 'nor_conv_1x1', 'nor_conv_3x3', 'avg_pool_3x3']
        self.max_nodes = 4

    @staticmethod
    def _benchmark_dataset(dataset):
        """Map the user-facing dataset name to the name queried in the benchmark.

        'cifar10' is looked up as 'cifar10-valid'; every other name is passed
        through unchanged.
        """
        return "cifar10-valid" if dataset == 'cifar10' else dataset

    def dag_encoding(self, arch, deterministic, dataset='cifar10', st='tss'):
        """Encode *arch* as a graph dictionary with its 12-epoch benchmark results.

        Args:
            arch: NAS-Bench-201 architecture string.
            deterministic: forwarded to ``get_simul_train_epoch12_info``.
            dataset: dataset name; 'cifar10' is queried as 'cifar10-valid'.
            st: currently unused; kept so existing callers keep working.

        Returns:
            dict with keys ``num_vertices`` (8), ``adjacency`` (8x8 float32),
            ``operations`` (8x7 float32 one-hot rows), ``mask`` (eight ones),
            ``val_acc``, ``test_acc`` and ``time_cost``.

        Raises:
            ValueError: if *arch* contains an operation name that is not in ``OPS``.
        """
        # Column order of the one-hot encoding: output, input, then OPS.
        op_map = [OUTPUT, INPUT, *OPS]
        column_of = {name: column for column, name in enumerate(op_map)}

        # Vertex order: input, the six operations in string order, output.
        ops = [INPUT, *self.get_op_list(arch), OUTPUT]
        try:
            columns = [column_of[op] for op in ops]
        except KeyError as err:
            raise ValueError(
                'unknown operation {!r} in architecture {!r}'.format(err.args[0], arch)
            ) from err

        # Operation names are strings, so they must be mapped to their column
        # index first; comparing the name directly with an integer column never
        # matches and would leave every row all-zero.
        ops_onehot = np.zeros((len(ops), len(op_map)), dtype=np.float32)
        ops_onehot[np.arange(len(ops)), columns] = 1.0

        # A fresh array per call so callers may mutate the result safely.
        matrix = np.array(self._ADJACENCY, dtype=np.float32)

        val_acc, test_acc, time_cost = self.get_simul_train_epoch12_info(
            dataset, arch, deterministic=deterministic
        )
        return {
            'num_vertices': 8,
            'adjacency': matrix,
            'operations': ops_onehot,
            'mask': np.ones(8, dtype=np.float32),
            'val_acc': val_acc,
            'test_acc': test_acc,
            'time_cost': time_cost,
        }

    def get_simul_train_epoch12_info(self, dataset, arch, deterministic=True):
        """Return ``(valid_accuracy, test_accuracy, time_cost)`` of the 12-epoch run.

        The time cost is the sum of the benchmark's "train-all-time" and
        "valid-all-time" entries. ``is_random`` is passed to the API as
        ``not deterministic``.
        """
        index = self.api.query_index_by_arch(arch)
        results = self.api.get_more_info(
            index,
            self._benchmark_dataset(dataset),
            use_12epochs_result=True,
            is_random=(not deterministic),
        )

        time_costs = results["train-all-time"] + results["valid-all-time"]
        return results["valid-accuracy"], results["test-accuracy"], time_costs

    def get_simul_full_train_info(self, dataset, arch, deterministic=True):
        """Return ``(valid_accuracy, test_accuracy)`` from the full (non-12-epoch) runs.

        With ``deterministic=True`` the accuracies are averaged over all
        recorded runs; otherwise one value is drawn at random for each.
        Both values are rounded to 10 decimals.
        """
        index = self.api.query_index_by_arch(arch)
        runs = self.api.query_by_index(
            index, self._benchmark_dataset(dataset), use_12epochs_result=False
        )

        val_accs = [run.get_eval('x-valid')['accuracy'] for run in runs.values()]
        test_accs = [run.get_eval('ori-test')['accuracy'] for run in runs.values()]

        if deterministic:
            return round(np.mean(val_accs), 10), round(np.mean(test_accs), 10)
        # NOTE(review): the two draws are independent, so the validation and
        # test accuracy may come from different runs -- confirm this is intended.
        return round(np.random.choice(val_accs), 10), round(np.random.choice(test_accs), 10)

    def get_op_list(self, string):
        """Return the operation names of an architecture string, in string order.

        The string is split on '|'; the empty leading/trailing tokens and the
        '+' separators are dropped, and the '~<node>' suffix of each remaining
        token is stripped.
        """
        tokens = string.split('|')
        # Swapping ops[2] and ops[3] was tried and gave worse results than
        # leaving them in string order, so the order is kept as-is.
        return [
            token.split('~')[0]
            for position, token in enumerate(tokens)
            if position not in self._SEPARATOR_TOKENS
        ]


if __name__ == '__main__':
    # Ad-hoc entry point: constructs the raw benchmark API without a file path.
    # NOTE(review): whether `API()` can locate a benchmark file with no
    # argument depends on the installed nas_201_api version -- confirm.
    nasbench = API()